Implement new experimental lookup-based matrix multiplication method (T-MAC) #26695
base: main
Conversation
Signed-off-by: Liqun Fu <[email protected]>
…as kernel not implemented for fp32. Also, I need to write the packing logic for the scales as well.
…ssert issue with the data shuffling in prepack
You can commit the suggested changes from lintrunner.
…unction signature
…le group size validation
include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
// Create a temporary threadpool for parallel packing
// This is used during model load time to speed up weight prepacking
What is the overhead like for creating a new threadpool in each call to PrePack()?
I wonder if we should make an existing threadpool available to this code. Perhaps we can pass in the threadpool from SessionState. Something to consider, maybe for a future PR.
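To illustrate the trade-off the reviewer is raising, here is a minimal Python sketch (an analogy only; the actual code uses ORT's C++ thread pool, and the function names below are hypothetical). Creating a fresh pool inside every packing call pays thread startup and teardown cost on each invocation, while a shared, session-owned pool pays it once:

```python
from concurrent.futures import ThreadPoolExecutor

# Hypothetical stand-in for per-tensor packing work; this is not the
# actual MLAS prepacking code.
def pack_chunk(chunk):
    return sum(chunk)

def pack_with_fresh_pool(chunks):
    # Mimics creating a temporary pool inside every PrePack() call:
    # worker threads are spawned and joined on each invocation.
    with ThreadPoolExecutor(max_workers=4) as pool:
        return list(pool.map(pack_chunk, chunks))

_shared_pool = ThreadPoolExecutor(max_workers=4)

def pack_with_shared_pool(chunks):
    # Mimics reusing a long-lived pool (e.g. one passed in from
    # SessionState) across calls: no per-call thread startup.
    return list(_shared_pool.map(pack_chunk, chunks))

chunks = [list(range(1000))] * 8
assert pack_with_fresh_pool(chunks) == pack_with_shared_pool(chunks)
```

Both variants produce identical results; the difference is purely in where the thread lifecycle cost is paid, which is why passing in an existing pool is attractive for model load time.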
I agree, passing a thread pool to PrePack would be clean. I am planning to create a second PR improving the prepacking logic in general; I will include this change there :)
auto scale_ptr = scales ? scales->DataRaw() : nullptr;
packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), scale_ptr,
                            has_zp_input_, nullptr, threadpool_ptr);
IIUC, the usage of a threadpool in the existing non-LUT path seems like a new addition - is that intentional (and does it come with appropriate tests)?
Initially, I thought the tests in test_sqnbitgemm.cpp should suffice since they already test it with a thread pool. I applied changes to only use the thread pool for the LUT path now.
Once we add tests, I think it might be beneficial to use the thread pool for prepacking in the other paths as well.
Description
This PR introduces a new experimental lookup-table (LUT) based matrix multiplication method for 2-bit MatMulNBits on x64 AVX2, inspired by the T-MAC paper and T-MAC repository, to speed up low-bit LLM inference.
Unlike the existing quant-dequant methods, the LUT-based method directly supports mixed-precision GEMM without dequantization. It uses bit-wise table lookups to eliminate multiplications and reduce the additions required in matrix multiplication.
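To make the idea concrete, here is a small self-contained Python sketch of T-MAC-style LUT matmul for 2-bit weights (a didactic model of the technique, not the AVX2 kernel in this PR). Each 2-bit weight is split into two bit planes; for each group of `g` activations, all `2**g` partial sums are precomputed once, so every binary dot product becomes table lookups and additions with no multiplications in the inner loop:

```python
import numpy as np

def lut_matmul_2bit(W_q, a, g=4):
    """LUT-based matvec for 2-bit unsigned weights (values 0..3).

    Splits weights into two bit planes (w = b0 + 2*b1) and replaces the
    binary dot products with table lookups: for each group of g
    activations, all 2**g possible partial sums are precomputed once
    and reused across every output row.
    """
    N, K = W_q.shape
    assert K % g == 0
    n_groups = K // g

    # Precompute, per activation group, the sum for every g-bit mask.
    luts = np.zeros((n_groups, 1 << g))
    for c in range(n_groups):
        chunk = a[c * g:(c + 1) * g]
        for mask in range(1 << g):
            luts[c, mask] = sum(chunk[j] for j in range(g) if mask >> j & 1)

    out = np.zeros(N)
    for plane in range(2):                 # 2-bit weights -> 2 bit planes
        bits = (W_q >> plane) & 1          # binary matrix for this plane
        for i in range(N):
            acc = 0.0
            for c in range(n_groups):
                # Pack the g bits of this row/group into a LUT index.
                idx = 0
                for j in range(g):
                    idx |= int(bits[i, c * g + j]) << j
                acc += luts[c, idx]        # lookup replaces g multiplies
            out[i] += (1 << plane) * acc   # weight the plane by 2**plane
    return out

rng = np.random.default_rng(0)
W_q = rng.integers(0, 4, size=(8, 32))
a = rng.standard_normal(32)
assert np.allclose(lut_matmul_2bit(W_q, a), W_q @ a)
```

The real kernel vectorizes the lookup step with AVX2 byte shuffles and operates on prepacked weights, but the table-construction-then-index structure is the same.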
This PR:
- Adds a `mlas.use_lut_gemm` session option allowing use of LUT GEMM inside MatMulNBits when it is available (2-bit, BlkLen multiple of 32, K multiple of 32, N multiple of 128, AVX2 present).
- Adds the `MlasLUTGemm` entry point that generates per-row LUTs and calls the AVX2 kernel.
- Adds `GenerateLUT_avx2` and the GEMM compute kernel `TMACComputeGemm_avx2`, and wires dispatch in MLAS platform init.

Main components:
- `MlasInitLUTGemmKernelConfig`: config for LUT kernels
- `MlasLUTGemmPackQuantBData`: prepacking of quantized weights
- `MlasLUTPackScalesAndZeroPoints`: prepacking of quantized scales and zero points
- `MlasLUTGemm`: main entry point
- `GenerateLUT_avx2`: LUT construction from activations
- `TMACComputeGemm_avx2`: AVX2 LUT GEMM kernel
- Session option: `mlas.use_lut_gemm`
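The availability conditions listed in the description can be summarized as a single predicate. This is an illustrative Python sketch only; the function name and signature are hypothetical and do not correspond to an actual ORT/MLAS API:

```python
def lut_gemm_available(nbits: int, blk_len: int, K: int, N: int,
                       has_avx2: bool) -> bool:
    """Hypothetical predicate mirroring the conditions the PR lists for
    taking the LUT GEMM path; when it is False, MatMulNBits falls back
    to the existing quant-dequant kernels."""
    return (
        nbits == 2              # only 2-bit weights are supported
        and blk_len % 32 == 0   # BlkLen must be a multiple of 32
        and K % 32 == 0         # K must be a multiple of 32
        and N % 128 == 0        # N must be a multiple of 128
        and has_avx2            # AVX2 is the only kernel implementation
    )

assert lut_gemm_available(2, 32, 4096, 4096, True)
assert not lut_gemm_available(4, 32, 4096, 4096, True)
```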
How to test
- Unit tests in `test_sqlutgemm.cpp`
- Set `mlas.use_lut_gemm=1` on AVX2 machines; expect fallback to the existing path if availability checks fail.

Perf
The focus of this PR is functional correctness and kernel bring-up; perf numbers will be reported separately once broader profiling is done.
Future Work